# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A lightweight one-file FSDP SFT Trainer
TODO(zhangchi.usc1992)
- Add calculation of mfu
- Add validation
"""

import os

# Process-wide environment defaults: NCCL debug verbosity and tokenizer parallelism.
# They are assigned before the imports that follow — presumably so the affected
# libraries see them when they initialize (TODO confirm).
os.environ["NCCL_DEBUG"] = "WARN"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

import logging
import re
from contextlib import nullcontext

import hydra
import torch
import torch.distributed
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from peft import LoraConfig, TaskType, get_peft_model
from tensordict import TensorDict
from torch import nn, optim
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel

import verl.utils.hdfs_io as hdfs_io
from verl.utils.dataset import SFTDataset
# from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.distributed import initialize_global_process_group
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn
from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from verl.utils.tracking import Tracking
from verl.utils.ulysses import (
    gather_outpus_and_unpad,
    get_ulysses_sequence_parallel_world_size,
    ulysses_pad_and_slice_inputs,
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from reil.utils.dataset.rg_dataset import prepare_reasoning_gym_sft_dataset
from reil.trainer.llm_agent.agent_proxy import LLMAgentProxy, HFWrapperWg
from typing import Dict, Any
from verl import DataProto
import numpy as np
from verl.workers.rollout.hf_rollout import HFRollout
from reasoning_gym.utils import extract_answer
from reil.trainer.main_ppo import get_custom_reward_fn
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from torchdata.stateful_dataloader import StatefulDataLoader

# Module-level logger. Its level is read from the VERL_SFT_LOGGING_LEVEL environment
# variable and falls back to "WARN" when that variable is unset.
# NOTE(review): the logger is keyed by __file__ (a path) rather than the usual __name__.
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN"))


def extract_step(path):
    """Return the integer that follows ``global_step_`` in *path*, or ``None`` if there is none.

    Only the first occurrence of the pattern is considered.
    """
    found = re.search(r"global_step_(\d+)", path)
    return int(found.group(1)) if found else None


def convert_to_regular_types(obj):
    """Recursively convert Hydra/OmegaConf containers into plain Python containers.

    ``DictConfig`` and ``dict`` values become ``dict``; ``ListConfig``, ``list`` and
    ``tuple`` values become ``list``. Anything else is returned unchanged.

    Args:
        obj: The value to convert. May be arbitrarily nested.

    Returns:
        A structure made only of ``dict``/``list`` containers and the original leaf values.
    """
    from omegaconf import DictConfig, ListConfig

    if isinstance(obj, (DictConfig, dict)):
        return {k: convert_to_regular_types(v) for k, v in obj.items()}
    if isinstance(obj, (ListConfig, list, tuple)):
        # Recurse into the elements too: a ListConfig may itself hold DictConfig/ListConfig
        # nodes, and a shallow list() would leave those unconverted.
        return [convert_to_regular_types(x) for x in obj]
    return obj

class Config:
    """Minimal attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def get(self, key: str, default: Any = None) -> Any:
        """Return the attribute named *key*, or *default* when it is not set."""
        return getattr(self, key, default)

class FSDPSFTTrainer:
    def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh, tokenizer, train_dataset: Dataset, val_dataset: Dataset):
        """Wire up data loading, the FSDP model/optimizer, evaluation helpers and reward functions.

        Args:
            config: Trainer configuration; the ``data`` and ``trainer`` sections are read here.
            device_mesh: Mesh used for rank queries and FSDP sharding.
            ulysses_device_mesh: Mesh handed to the Ulysses sharding manager.
            tokenizer: Tokenizer shared by the scoring dataset, rollout and reward functions.
            train_dataset: Dataset backing the training dataloader.
            val_dataset: Dataset backing the validation dataloader.

        Raises:
            NotImplementedError: If KL regularization is enabled while sequence parallelism
                (> 1) and remove-padding are both active.
        """
        self.config = config
        self.device_mesh = device_mesh
        self.ulysses_device_mesh = ulysses_device_mesh
        self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
        self.tokenizer = tokenizer

        # Convert the global batch size into a per-data-parallel-rank batch size.
        self._normalize_config_bsz()

        # Sequence-parallel size and remove-padding switch, with "off" defaults.
        self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1)
        self.use_remove_padding = getattr(self.config, "use_remove_padding", False)
        if self.device_mesh.get_rank() == 0:
            print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}")
            print(f"Using remove padding: {self.use_remove_padding}")

        # Optional prompt dataset used for generation-based validation scoring.
        score_dataset = None
        if self.config.data.get('val_score_files', None):
            data_cfg = config.data
            score_dataset = RLHFDataset(
                parquet_files=data_cfg.val_score_files,
                tokenizer=tokenizer,
                prompt_key=data_cfg.prompt_key,
                image_key=data_cfg.get('image_key', 'images'),
                max_prompt_length=data_cfg.max_prompt_length,
                chat_template=data_cfg.get('chat_template', False),
                filter_prompts=True,
                return_raw_chat=data_cfg.get('return_raw_chat', False),
                truncation=data_cfg.get('truncation', 'error'),
                filter_overlong_prompts=data_cfg.filter_overlong_prompts,
            )
        self.val_score_dataset = score_dataset
        self._build_dataloader(train_dataset, val_dataset)

        # The dataloaders must exist first: the scheduler length is derived from them.
        self._build_model_optimizer()

        # TODO: add checkpoint manager
        if self.device_mesh.get_rank() == 0:
            print(self.config)
        self.eval_model: PreTrainedModel | None = None
        self.ref_model: PreTrainedModel | None = None
        self.fsdp_ref_model: FSDP | None = None

        # Every rank takes part in the state-dict gather to avoid FSDP deadlocks;
        # only rank 0 goes on to create the agent proxy.
        wants_policy_eval = self.config.data.type != 'reasoning_gym' and self.config.trainer.policy_eval
        if wants_policy_eval:
            self._sync_eval_model_from_fsdp_all_ranks()
            if self.device_mesh.get_rank() == 0:
                self.init_agent_proxy()
        self.init_rollout()
        self.init_reward_function()
        if getattr(self, 'val_reward_fn', None) is None:
            print("No reward function is initialized")

        # Optional frozen reference model for KL regularization (non-SP path only).
        trainer_cfg = getattr(self.config, 'trainer', None)
        kl_cfg = None if trainer_cfg is None else getattr(trainer_cfg, 'kl_regularization', None)
        if kl_cfg is None or not getattr(kl_cfg, 'enabled', False):
            return
        if self.config.ulysses_sequence_parallel_size > 1 and self.use_remove_padding:
            raise NotImplementedError("KL regularization with sequence parallel/remove-padding path is not supported yet")
        self._build_ref_model()

    def _normalize_config_bsz(self):
        dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)
        if self.device_mesh.get_rank() == 0:
            print(f"Normalize batch size by dp {dp_size}")

        assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}"

        self.config.data.train_batch_size //= dp_size

        assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0

    def _build_dataloader(self, train_dataset, val_dataset):
        # build dataset
        config = self.config
        self.train_dataset, self.val_dataset = train_dataset, val_dataset

        # build dataloader
        # Use data parallel rank and size instead of global rank and world size

        # If doing SP, we need to use the local rank and size
        if self.config.ulysses_sequence_parallel_size > 1:
            rank = self.ulysses_device_mesh.get_local_rank("dp")
            world_size = self.ulysses_device_mesh.size(0)
            if self.ulysses_device_mesh.get_rank() == 0:
                print(f"Using SP rank {rank} and size {world_size} for data distribution")
                print("Each SP rank gets different data, but the same data WITHIN the same rank")
        else:
            rank = self.device_mesh.get_rank()
            world_size = self.device_mesh.size()
        if self.device_mesh.get_rank() == 0:
            print(f"Using FSDP rank {rank} and size {world_size} for data distribution")

        self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True)
        self.train_dataloader = DataLoader(
            dataset=self.train_dataset,
            batch_size=config.data.train_batch_size,
            sampler=self.train_sampler,
            num_workers=8,
            pin_memory=True,
            drop_last=True,
        )

        self.val_sampler = DistributedSampler(self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True)
        self.val_dataloader = DataLoader(
            dataset=self.val_dataset,
            batch_size=config.data.micro_batch_size_per_gpu,
            # batch_size=len(self.val_dataset)//world_size,
            sampler=self.val_sampler,
            num_workers=8,
            pin_memory=True,
            drop_last=True,
        )
        print(f"Validation dataset length: {len(self.val_dataset)}")
        print(f"Validation dataloader length: {len(self.val_dataloader)}")
        # assert len(self.val_dataloader) == 1, "Validation dataloader should have only one batch"
        if self.val_score_dataset is not None:
            self.val_score_dataloader = StatefulDataLoader(
            dataset=self.val_score_dataset,
            # Validation datasets are sent to inference engines as a whole batch,
            # which will schedule the memory themselves.
            batch_size=len(self.val_score_dataset),
            num_workers=8,
            shuffle=True,
            drop_last=False,
            collate_fn=collate_fn)

    def init_reward_function(self):
        if self.config.data.type == 'reasoning_gym':
            self.reward_fn = lambda data: self._score_output(data, num_examine=0, is_val=False)
            self.val_reward_fn = lambda data: self._score_output(data, num_examine=1, is_val=True)
        else:
            # TODO: add reward function for standard SFT
            reward_manager_name = self.config.reward_model.get("reward_manager", "naive")
            if reward_manager_name == 'naive':
                from verl.workers.reward_manager import NaiveRewardManager
                reward_manager_cls = NaiveRewardManager
            elif reward_manager_name == 'prime':
                from verl.workers.reward_manager import PrimeRewardManager
                reward_manager_cls = PrimeRewardManager
            elif reward_manager_name == 'complete':
                from reil.workers.reward_manager import CompleteRewardManager
                reward_manager_cls = CompleteRewardManager
            elif reward_manager_name == 'gp_l':
                from reil.workers.reward_manager import GPLRewardManager
                reward_manager_cls = GPLRewardManager
            else:
                raise NotImplementedError
            compute_score = get_custom_reward_fn(self.config)
            self.reward_fn = reward_manager_cls(tokenizer=self.tokenizer, num_examine=0, compute_score=compute_score)
            self.val_reward_fn = reward_manager_cls(tokenizer=self.tokenizer, num_examine=1, compute_score=compute_score)

    def _score_output(self, data: DataProto, num_examine: int = 0, is_val: bool = False) -> torch.Tensor:
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

        num_printed = 0
        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem
            # print(data_item.non_tensor_batch)
            prompt_ids = data_item.batch["prompts"]  # tokenized prompts
            prompt_length = prompt_ids.shape[-1]

            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch["responses"]
            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            prompt_str = self.tokenizer.decode(valid_prompt_ids)
            response_str = self.tokenizer.decode(valid_response_ids)
            sequences_str = prompt_str + response_str

            index = data_item.non_tensor_batch["index"]
            correctness_score = self._compute_correctness_score(
                solution_str=response_str,
                index=index,
                is_val=is_val
            )
            if self.config.reward.use_accuracy:
                reward_components = {"correctness": correctness_score}
                total_reward = correctness_score
            else:
                reward_components = {}
                total_reward = 0

            reward_tensor[i, valid_response_length - 1] = total_reward

            if num_printed < num_examine:
                components = ", ".join([f"{k}={v:.2f}" for k, v in reward_components.items()])
                print(f"(score={total_reward}, seq={sequences_str}, response={response_str})")
                print(f"reward={total_reward:.2f} ({components})")
                num_printed += 1

        return reward_tensor

    def _compute_correctness_score(self, solution_str: str, index: int, is_val: bool = False) -> float:
        """Score the ``<answer>``-tagged answer extracted from *solution_str*.

        Args:
            solution_str: Decoded model response.
            index: Position of the matching entry in the dataset's ``data``.
            is_val: Use the validation dataset when true, the training dataset otherwise.

        Returns:
            The score from the dataset's ``experiment`` (looked up by the entry's
            ``metadata.entry_id``) when an experiment is set, otherwise from
            ``data.score_answer``.
        """
        answer = extract_answer(solution_str, tag_name="answer")
        dataset = self.val_dataset if is_val else self.train_dataset
        records = dataset.data
        entry = records[index]

        if not dataset.experiment:
            return records.score_answer(answer, entry=entry)
        return dataset.experiment.score_answer_with_id(answer, entry["metadata"]["entry_id"])

    def _build_model_optimizer(self) -> None:
        """Load the pretrained model, wrap it in FSDP, and create the optimizer and LR scheduler.

        Sets ``self.local_model_path``, ``self.model``, ``self.fsdp_model``,
        ``self.anchor_enabled``, ``self.optimizer``, ``self.steps_per_epoch``,
        ``self.total_steps`` and ``self.lr_scheduler``. When anchor regularization is
        enabled it additionally sets ``self.anchor_coeff`` and ``self.anchor_snapshots``.

        Requires ``self.train_dataloader`` to exist already: its length defines the number
        of steps per epoch and therefore the scheduler horizon.

        Raises:
            AssertionError: If sequence parallelism (> 1) is requested without remove-padding.
            ValueError: If ``config.optim.lr_scheduler`` is set to something other than
                ``"cosine"`` or ``"constant"``.
        """
        # TODO (zhangchi.usc1992):
        # 1. support pretrain from random weights
        # 2. support init directly from sharded weights
        self.local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)

        if self.config.model.get("external_lib", None) is not None:
            # This is used to import external_lib into the huggingface systems
            import importlib

            importlib.import_module(self.config.model.external_lib)

        log_gpu_memory_usage("Before model allocation", logger=logger)

        trust_remote_code = self.config.model.trust_remote_code
        # load config first
        config = AutoConfig.from_pretrained(self.local_model_path, trust_remote_code=trust_remote_code)
        if self.config.ulysses_sequence_parallel_size > 1:
            assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled"

        # This may be very large
        # Meta-tensor initialization is requested only when word embeddings are not tied.
        init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh)

        with init_context():
            # Weights are loaded in fp32; the bf16 compute dtype is applied later via FSDP mixed precision.
            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
                self.local_model_path,
                config=config,
                torch_dtype=torch.float32,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

            if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1:
                from verl.models.transformers.monkey_patch import apply_monkey_patch

                apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size)

            # Apply Liger kernel if use_liger is enabled
            if self.config.model.get("use_liger", False):
                from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

                _apply_liger_kernel_to_instance(model=self.model)

            # A positive lora_rank replaces self.model with a PEFT (LoRA) model.
            if self.config.model.get("lora_rank", 0) > 0:
                self.model.enable_input_require_grads()
                # Convert config to regular Python types before creating PEFT model
                lora_config = {
                    "task_type": TaskType.CAUSAL_LM,
                    "r": self.config.model.lora_rank,
                    "lora_alpha": self.config.model.lora_alpha,
                    "target_modules": convert_to_regular_types(self.config.model.target_modules),
                    "bias": "none",
                }
                self.model = get_peft_model(self.model, LoraConfig(**lora_config))

        if self.config.model.enable_gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

        log_gpu_memory_usage("After model allocation", logger=logger)

        # bf16 parameters, fp32 gradient reduction and fp32 buffers.
        mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

        auto_wrap_policy = get_fsdp_wrap_policy(
            self.model,
            config=self.config.model.fsdp_config.wrap_policy,
            is_lora=self.config.model.get("lora_rank", 0) > 0,
        )
        if self.device_mesh.get_rank() == 0:
            print(auto_wrap_policy)

        if not self.config.model.fsdp_config.cpu_offload:
            cpu_offload = None
        else:
            cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)

        self.fsdp_model = FSDP(
            module=self.model,
            auto_wrap_policy=auto_wrap_policy,
            param_init_fn=init_fn,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            device_mesh=self.device_mesh,
            sync_module_states=True,
            device_id=torch.cuda.current_device(),
            cpu_offload=cpu_offload,
            use_orig_params=False,
        )

        log_gpu_memory_usage("After FSDP wrapping", logger=logger)

        # Optional anchor regularization: keep a detached copy of every trainable FSDP
        # parameter as it is right after wrapping, i.e. before any optimizer step.
        anchor_cfg = getattr(getattr(self.config, 'trainer', None), 'anchor_regularization', None)
        self.anchor_enabled = bool(anchor_cfg is not None and getattr(anchor_cfg, 'enabled', False))
        if self.anchor_enabled:
            self.anchor_coeff = float(getattr(anchor_cfg, 'l2_anchor_coeff', 0.0))
            self.anchor_snapshots: list[torch.Tensor] = []
            with torch.no_grad():
                for p in self.fsdp_model.parameters():
                    if not p.requires_grad:
                        continue
                    self.anchor_snapshots.append(p.detach().clone())  # same device/dtype

        self.optimizer = optim.AdamW(
            self.fsdp_model.parameters(),
            lr=self.config.optim.lr,
            betas=self.config.optim.betas,
            weight_decay=self.config.optim.weight_decay,
        )

        log_gpu_memory_usage("After initialize optimizer", logger=logger)

        self.steps_per_epoch = len(self.train_dataloader)
        self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs

        if self.device_mesh.get_rank() == 0:
            print(f"Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}")

        num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio)

        # Cosine is the default when no scheduler is configured.
        if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine":
            self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps)
        # elif self.config.optim.lr_scheduler == "wsd":
        #     self.lr_scheduler = get_wsd_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps)
        elif self.config.optim.lr_scheduler == "constant":
            # The constant schedule ignores warmup_steps_ratio: warmup is forced to zero steps.
            num_warmup_steps = 0
            self.lr_scheduler = get_constant_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps)
        else:
            raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}")

    def _build_ref_model(self) -> None:
        """Build a frozen, FSDP-sharded reference model loaded from ``self.local_model_path``.

        The reference model reuses ``self.model.config``, is loaded with the same dtype and
        attention implementation as the trained model, gets no LoRA adapters or gradient
        checkpointing, is put in eval mode, and has ``requires_grad`` disabled on every
        parameter. The results are stored in ``self.ref_model`` and ``self.fsdp_ref_model``.

        Raises:
            AssertionError: If a reference model has already been built.
        """
        assert self.ref_model is None and self.fsdp_ref_model is None, "Reference model already initialized"

        trust_remote_code = self.config.model.trust_remote_code
        config = self.model.config

        log_gpu_memory_usage("Before ref model allocation", logger=logger)

        # Build reference model weights
        init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh)
        with init_context():
            ref_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
                self.local_model_path,
                config=config,
                torch_dtype=torch.float32,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

        # Do NOT apply LoRA or gradient checkpointing to ref model
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad_(False)

        # Same precision policy and wrap-policy config as the trained model, but with is_lora=False.
        mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
        auto_wrap_policy = get_fsdp_wrap_policy(
            ref_model,
            config=self.config.model.fsdp_config.wrap_policy,
            is_lora=False,
        )

        if not self.config.model.fsdp_config.cpu_offload:
            cpu_offload = None
        else:
            cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)

        fsdp_ref_model = FSDP(
            module=ref_model,
            auto_wrap_policy=auto_wrap_policy,
            param_init_fn=init_fn,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            device_mesh=self.device_mesh,
            sync_module_states=True,
            device_id=torch.cuda.current_device(),
            cpu_offload=cpu_offload,
            use_orig_params=False,
        )

        # Freeze again after wrapping: FSDP exposes its own parameter objects.
        # NOTE(review): presumably needed because use_orig_params=False flattens parameters — confirm.
        fsdp_ref_model.eval()
        for p in fsdp_ref_model.parameters():
            p.requires_grad_(False)

        log_gpu_memory_usage("After ref FSDP wrapping", logger=logger)

        self.ref_model = ref_model
        self.fsdp_ref_model = fsdp_ref_model

    def _compute_ref_logp(self, batch: TensorDict) -> torch.Tensor:
        """Compute reference log-probabilities for shifted tokens and return a pinned CPU tensor.

        The batch is processed in micro-batches of ``config.data.micro_batch_size_per_gpu``
        (the whole batch if that setting is absent) so that only one micro-batch of logits
        is resident on the GPU at a time.

        Args:
            batch: Must provide ``input_ids``, ``attention_mask`` and ``position_ids``.

        Returns:
            A tensor shaped [batch, seq_len-1, vocab_size], dtype bfloat16, on CPU (pinned).

        Raises:
            AssertionError: If the reference model has not been built.
        """
        assert self.fsdp_ref_model is not None, "Reference model is not initialized"

        device = torch.cuda.current_device()
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        position_ids = batch["position_ids"].to(device, non_blocking=True)

        batch_size = input_ids.size(0)
        micro_bs = int(getattr(self.config.data, "micro_batch_size_per_gpu", batch_size))
        chunks_cpu: list[torch.Tensor] = []

        self.fsdp_ref_model.eval()
        for start in range(0, batch_size, micro_bs):
            end = min(start + micro_bs, batch_size)
            ids = input_ids[start:end]
            am = attention_mask[start:end]
            pos = position_ids[start:end]
            with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = self.fsdp_ref_model(
                    input_ids=ids,
                    attention_mask=am,
                    position_ids=pos,
                    use_cache=False,
                ).logits  # [mb, S, V]
            # Drop the last position: there is no next token for it to predict.
            shift_logits = logits[..., :-1, :].contiguous()  # [mb, S-1, V]
            # Softmax in fp32 for accuracy, then store as bf16 to halve host memory.
            logp_chunk = torch.log_softmax(shift_logits.float(), dim=-1).to(torch.bfloat16)
            chunks_cpu.append(logp_chunk.cpu())
            # Release GPU references before the next micro-batch is computed.
            del ids, am, pos, logits, shift_logits, logp_chunk

        # torch.cat always allocates fresh pageable memory, so pinning the individual chunks
        # would be lost; pin the concatenated result once so the returned tensor is pinned.
        ref_logp_cpu = torch.cat(chunks_cpu, dim=0).pin_memory()
        return ref_logp_cpu

    def init_agent_proxy(self):
        """Create ``self.proxy``, an ``LLMAgentProxy`` driven by the standalone eval model."""
        eval_module = self._get_or_build_eval_model()
        actor_wg = HFWrapperWg(self.config, self.tokenizer, module=eval_module)
        self.proxy = LLMAgentProxy(self.config, actor_wg, self.tokenizer)

    def _get_or_build_eval_model(self) -> PreTrainedModel:
        if self.eval_model is None:
            self.eval_model = self._build_eval_model_from_fsdp()
        return self.eval_model

    def _gather_full_state_dict_all_ranks(self):
        """Gather the unsharded FSDP state dict; must be called on every rank.

        All ranks take part in the underlying collectives, but only rank 0 materializes
        the full, CPU-offloaded state dict.

        Returns:
            The full state dict on rank 0, ``None`` on every other rank.
        """
        from torch.distributed.fsdp import StateDictType, FullStateDictConfig

        # rank0_only=True: non-zero ranks would discard their copy below anyway, so do not
        # build a full CPU copy of the model on each of them.
        # NOTE(review): assumes FSDP's rank 0 is the same process as device_mesh rank 0
        # (true for a 1-D mesh over all ranks) — confirm against the mesh construction.
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
            full_state_dict = self.fsdp_model.state_dict()
        return full_state_dict if self.device_mesh.get_rank() == 0 else None

    def _build_eval_model_from_fsdp(self) -> PreTrainedModel:
        """Create a standalone (non-FSDP) eval model on GPU carrying the current FSDP weights.

        The model is instantiated from ``self.local_model_path`` in bfloat16 when the model
        config says bfloat16, otherwise in float32, and then loaded non-strictly from the
        gathered full state dict. Key mismatches are reported by count only.

        Raises:
            AssertionError: If executed on a rank other than 0 (checked after the gather).
        """
        state = self._gather_full_state_dict_all_ranks()
        assert self.device_mesh.get_rank() == 0, "_build_eval_model_from_fsdp should be called only on rank 0"

        hf_config = self.model.config
        dtype = torch.bfloat16 if hf_config.torch_dtype == torch.bfloat16 else torch.float32
        eval_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
            self.local_model_path,
            config=hf_config,
            torch_dtype=dtype,
            attn_implementation="flash_attention_2",
            trust_remote_code=self.config.model.trust_remote_code,
        ).cuda()

        missing, unexpected = eval_model.load_state_dict(state, strict=False)
        if missing or unexpected:
            print(f"[Eval Model] Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
        eval_model.eval()
        return eval_model

    def _sync_eval_model_from_fsdp_all_ranks(self) -> None:
        """Refresh the rank-0 eval model from the current FSDP weights; call on every rank.

        Every rank joins the full-state-dict gather, which avoids the deadlock that occurs
        when FSDP collectives run on a subset of ranks. Rank 0 then creates the eval model
        if it does not exist yet and loads the gathered weights non-strictly. A barrier at
        the end keeps the other ranks waiting until rank 0 is done.
        """
        state = self._gather_full_state_dict_all_ranks()

        if self.device_mesh.get_rank() == 0:
            if self.eval_model is None:
                hf_config = self.model.config
                dtype = torch.bfloat16 if hf_config.torch_dtype == torch.bfloat16 else torch.float32
                self.eval_model = AutoModelForCausalLM.from_pretrained(
                    self.local_model_path,
                    config=hf_config,
                    torch_dtype=dtype,
                    attn_implementation="flash_attention_2",
                    trust_remote_code=self.config.model.trust_remote_code,
                ).cuda()
            missing, unexpected = self.eval_model.load_state_dict(state, strict=False)
            if missing or unexpected:
                print(f"[Eval Model Sync] Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
            self.eval_model.eval()

        # Hold every rank here until rank 0 has finished preparing the eval model.
        torch.distributed.barrier()

    def _policy_eval_and_log(self, tracking: Tracking, global_step: int) -> None:
        """Roll out the policy through the agent proxy on rank 0 and log the resulting metrics.

        Expects the eval model to have been synchronized beforehand.

        Args:
            tracking: Logger that receives the rollout metrics.
            global_step: Step under which the metrics are recorded.

        Raises:
            AssertionError: If called on a rank other than 0.
        """
        assert self.device_mesh.get_rank() == 0
        eval_module = self._get_or_build_eval_model()
        self.proxy.set_actor_wg(HFWrapperWg(self.config, self.tokenizer, module=eval_module))
        metrics = self.proxy.rollout().meta_info['metrics']
        tracking.log(data=metrics, step=global_step)

    def _validate(self):
        """Generate responses for the scoring dataset and return the mean reward per data source.

        Each batch of ``self.val_score_dataloader`` is repeated ``val_kwargs.n`` times,
        passed through ``self.rollout`` for generation and scored with
        ``self.val_reward_fn``. Token-level rewards are summed per sample and averaged
        per ``data_source`` (``'unknown'`` when the batch carries no such field).

        Returns:
            A dict mapping ``val/test_score/<data_source>`` to the mean sample reward.
        """
        per_batch_rewards = []
        per_batch_sources = []

        # Decoded samples, collected for an (currently disabled) generation table.
        sample_inputs = []
        sample_outputs = []
        sample_scores = []

        for raw_batch in self.val_score_dataloader:
            batch = DataProto.from_single_dict(raw_batch)

            rollout_val_cfg = self.config.actor_rollout_ref.rollout.val_kwargs
            batch = batch.repeat(repeat_times=rollout_val_cfg.n, interleave=True)

            # Keep the prompts as text before the prompt tensors are popped off.
            sample_inputs.extend(
                self.tokenizer.decode(ids, skip_special_tokens=True) for ids in batch.batch['input_ids']
            )

            # Multi-modal batches carry two extra non-tensor fields into generation.
            non_tensor_keys = ['raw_prompt_ids']
            if 'multi_modal_inputs' in batch.non_tensor_batch.keys():
                non_tensor_keys += ['multi_modal_data', 'multi_modal_inputs']
            test_gen_batch = batch.pop(
                batch_keys=['input_ids', 'attention_mask', 'position_ids'],
                non_tensor_batch_keys=non_tensor_keys,
            )

            test_gen_batch.meta_info = {
                'eos_token_id': self.tokenizer.eos_token_id,
                'pad_token_id': self.tokenizer.pad_token_id,
                'recompute_log_prob': False,
                'do_sample': rollout_val_cfg.do_sample,
                'validate': True,
            }
            print(f'test_gen_batch meta info: {test_gen_batch.meta_info}')

            self.model.eval()
            generated = self.rollout.generate_sequences(test_gen_batch)
            print('validation generation end')

            sample_outputs.extend(
                self.tokenizer.decode(ids, skip_special_tokens=True) for ids in generated.batch['responses']
            )

            batch = batch.union(generated)

            # Score the prompt + response batch with the validation reward function.
            reward_tensor = self.val_reward_fn(batch)
            sample_scores.extend(reward_tensor.sum(-1).cpu().tolist())

            per_batch_rewards.append(reward_tensor)
            fallback_sources = ['unknown'] * reward_tensor.shape[0]
            per_batch_sources.append(batch.non_tensor_batch.get('data_source', fallback_sources))

        sample_totals = torch.cat(per_batch_rewards, dim=0).sum(-1).cpu()  # (batch_size,)
        sources = np.concatenate(per_batch_sources, axis=0)

        # Group the per-sample totals by the data source they came from.
        grouped = {}
        for idx in range(sample_totals.shape[0]):
            grouped.setdefault(sources[idx], []).append(sample_totals[idx].item())

        return {f'val/test_score/{source}': np.mean(values) for source, values in grouped.items()}

    def init_rollout(self):
        """Create the ``HFRollout`` used for generation and store it on ``self.rollout``.

        The rollout wraps ``self.model``. Its micro batch size is the length of
        ``self.val_dataset``; ``response_length`` and ``do_sample`` are read from
        ``config.actor_rollout_ref.rollout`` while ``temperature``, ``top_p`` and
        ``top_k`` are read from its ``val_kwargs`` sub-section.
        """
        val_set_size = len(self.val_dataset)
        rollout_section = self.config.actor_rollout_ref.rollout
        val_sampling = rollout_section.val_kwargs

        rollout_settings = {
            "micro_batch_size": val_set_size,
            "response_length": rollout_section.response_length,
            "do_sample": rollout_section.do_sample,
            "temperature": val_sampling.temperature,
            "top_p": val_sampling.top_p,
            "top_k": val_sampling.top_k,
        }
        self.rollout = HFRollout(self.model, Config(**rollout_settings))

    def _compute_loss_and_backward(self, batch, do_backward=True, ref_logp_mb: torch.Tensor | None = None):
        """Compute the SFT loss for one micro-batch and optionally backpropagate it.

        Two forward paths exist:

        * Padded path (default): supports the "standard", "dft" and "aft" loss
          weightings, an optional KL term against ``ref_logp_mb`` and the optional
          anchor L2 term (when ``self.anchor_enabled`` is set).
        * Remove-padding + Ulysses sequence-parallel path: only "standard" SFT is
          supported; neither the KL term nor the anchor term is applied there.

        The token losses are summed and divided by the number of valid (masked-in)
        tokens. When ``config.data.balance_dp_token`` is set, the token count is
        all-reduced and the result is multiplied by the data-parallel size.

        Args:
            batch: Mapping providing "input_ids", "attention_mask", "position_ids"
                and "loss_mask". "loss_mask" is popped from ``batch``.
            do_backward: When True, ``backward()`` is called on the total loss.
            ref_logp_mb: Optional reference log-probabilities. They are subtracted
                from the model log-probs of the shifted logits, so they presumably
                have shape [B, S-1, V] (or broadcast to it) -- confirm with caller.

        Returns:
            ``(loss, ce_loss, kl_loss)``: the total loss (re-weighted CE plus the
            optional KL and anchor terms), the plain masked cross-entropy, and the
            KL term (a zero tensor when no KL term was computed).

        Raises:
            ValueError: If ``config.trainer.sft_type`` is not one of "standard",
                "dft" or "aft" (padded path).
        """
        use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1
        if use_sp:
            assert self.config.trainer.sft_type == "standard", "Sequence parallel is only supported for standard SFT"

        # Move inputs to GPU and prepare loss mask
        input_ids = batch["input_ids"].cuda()
        attention_mask = batch["attention_mask"].cuda()
        position_ids = batch["position_ids"].cuda()
        # Drop the last position so the mask lines up with the shifted (next-token) targets
        loss_mask = batch.get("loss_mask")[:, :-1].cuda().reshape(-1)
        batch.pop("loss_mask", None)
        loss_fct = nn.CrossEntropyLoss(reduction="none")

        # Optional terms; they stay None on paths that do not compute them so the
        # shared reduction code below never touches an unbound name.
        kl_sum = None
        anchor_term = None

        # Context manager for sequence parallel if needed
        context = self.sharding_manager if use_sp else nullcontext()
        with context, torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            if not use_sp:
                # Standard forward pass without sequence parallel
                labels = input_ids[:, 1:].contiguous()
                output = self.fsdp_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False)

                # Shifted logits for next-token prediction; the 3D form is kept for the KL term
                shift_logits_3d = output.logits[..., :-1, :].contiguous()
                # Flatten the tokens using the logits' own last dimension
                shift_logits = shift_logits_3d.view(-1, shift_logits_3d.size(-1))
                # Enable model parallelism
                shift_labels = labels.view(-1).to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
                ce_loss = loss.clone()

                # Optional: KL(model || reference) regularization with a frozen reference
                if ref_logp_mb is not None:
                    # model log-probs and probs from shifted logits
                    log_p_s = torch.log_softmax(shift_logits_3d.float(), dim=-1)
                    p_s = torch.exp(log_p_s)
                    # match dtype and device for ref
                    ref_logp_mb = ref_logp_mb.to(device=log_p_s.device, dtype=log_p_s.dtype, non_blocking=True)
                    # per-token KL over vocab, flattened to line up with the loss mask
                    kl_per_token = torch.sum(p_s * (log_p_s - ref_logp_mb), dim=-1).reshape(-1)
                    # masked sum; normalized below with the same token count as CE
                    kl_sum = torch.sum(kl_per_token * loss_mask.to(kl_per_token.device))

                sft_type = self.config.trainer.sft_type
                if sft_type in ("dft", "aft"):
                    # Probability the model assigns to each target token (no gradient through the weight)
                    probs = torch.softmax(shift_logits, dim=-1)
                    target_probs = probs.gather(1, shift_labels.unsqueeze(-1)).squeeze(-1).detach()
                    if sft_type == "dft":
                        loss = loss * (1e-4 + target_probs)
                    else:
                        # Get the power parameter, default to 1 if not specified
                        aft_power = getattr(self.config.trainer, 'aft_power', 1.0)
                        loss = loss * torch.pow(1 - target_probs, aft_power)
                elif sft_type != "standard":
                    raise ValueError(f"Unknown SFT type: {sft_type}")

                loss = loss * loss_mask.to(loss.device)
                ce_loss = ce_loss * loss_mask.to(ce_loss.device)

                # Anchor L2 regularization
                if getattr(self, 'anchor_enabled', False):
                    # only compute on main microbatch path to avoid double-scaling across SP
                    reg_sum = torch.tensor(0.0, device=loss.device, dtype=torch.float32)
                    p_idx = 0
                    for p in self.fsdp_model.parameters():
                        if not p.requires_grad:
                            continue
                        # anchor_snapshots is indexed by position among trainable parameters
                        base = self.anchor_snapshots[p_idx]
                        reg_sum = reg_sum + (p.float() - base.float()).pow(2).sum()
                        p_idx += 1
                    # use raw L2 sum without parameter-count normalization
                    anchor_term = float(self.anchor_coeff) * reg_sum
                    # also accumulate raw (unscaled) L2 sum for logging
                    self.anchor_step_l2 = self.anchor_step_l2 + reg_sum.detach()
            else:
                # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks
                # i.e., each GPU has <1 sequence, and each SP group has 1 sequence
                # 1. All SP ranks will receive the *SAME* batch
                # 2. Different SP groups will receive *DIFFERENT* batches
                # This is implemented by the DistributedSampler

                batch_size, seqlen = input_ids.shape
                # Remove padding
                input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # Unpad position_ids to align rotary
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1)

                # Pad and slice inputs for sequence parallelism
                input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
                # For computing loss
                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)
                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size())
                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)

                # Forward pass
                output = self.fsdp_model(
                    input_ids=input_ids_rmpad_sliced,
                    attention_mask=None,  # Not needed with flash attention varlen
                    position_ids=position_ids_rmpad_padded,
                    use_cache=False,
                )

                # Compute loss locally then aggregate
                logits_rmpad = output.logits.squeeze(0)
                input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device)
                loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)
                # Gather and unpad for sequence parallelism
                loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)

                # This is the loss collected from all ulysses ranks
                full_loss = pad_input(hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen)
                full_loss = full_loss.squeeze(-1)[:, :-1]  # Remove last token's loss
                full_loss = full_loss.reshape(-1)
                loss_mask = loss_mask.to(full_loss.device)
                loss = full_loss * loss_mask
                # Only "standard" SFT reaches this path, so the CE loss is the loss itself
                ce_loss = loss

            valid_token_this_rank = torch.sum(loss_mask)

            if self.config.data.balance_dp_token:
                torch.distributed.all_reduce(valid_token_this_rank)
                dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size()
            else:
                dp_size = 1

            # Epsilon keeps the division finite when every token is masked out
            normalizer = valid_token_this_rank + 1e-8
            loss = torch.sum(loss) / normalizer * dp_size
            ce_loss = torch.sum(ce_loss) / normalizer * dp_size

            # Add KL term if one was computed (reduced exactly like CE)
            kl_loss = None
            if kl_sum is not None:
                kl_loss = kl_sum / normalizer * dp_size
                kl_coef = getattr(getattr(self.config.trainer, 'kl_regularization', None), 'kl_coef', 0.05)
                loss = loss + float(kl_coef) * kl_loss

            # Add anchor term if one was computed
            if anchor_term is not None:
                loss = loss + anchor_term

            if do_backward:
                loss.backward()

            # Return kl_loss (or 0.0) for logging convenience
            if kl_loss is None:
                kl_loss = torch.tensor(0.0, device=loss.device)
            return loss, ce_loss, kl_loss

    def training_step(self, batch: TensorDict):
        """Run one optimizer step over ``batch`` using gradient accumulation.

        The batch is split into micro-batches of ``config.data.micro_batch_size_per_gpu``;
        each one is forwarded and back-propagated by ``_compute_loss_and_backward``.
        The accumulated gradients are then divided by the number of micro-batches so
        that they correspond to the mean micro-batch loss that is reported, clipped
        with ``config.optim.clip_grad``, and applied unless the gradient norm is
        non-finite. The LR scheduler is stepped in either case.

        Args:
            batch: The per-rank training batch.

        Returns:
            Dict of scalar metrics ("train/loss", "train/lr", "train/ce_loss",
            "train/grad_norm", plus KL and anchor metrics when those are enabled),
            with the losses averaged across ranks.
        """
        self.fsdp_model.train()

        log_gpu_memory_usage("Before optimizer zero_grad", logger=logger)

        self.optimizer.zero_grad()

        log_gpu_memory_usage("After optimizer zero_grad", logger=logger)

        micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)
        n_micro_batches = len(micro_batches)
        step_loss = 0
        step_ce_loss = 0
        step_kl_loss = 0
        anchor_enabled = getattr(self, 'anchor_enabled', False)
        if anchor_enabled:
            # Reset the per-step accumulator that _compute_loss_and_backward adds to
            self.anchor_step_l2 = torch.tensor(0.0, device=torch.cuda.current_device(), dtype=torch.float32)
        # Prepare reference log-probs once per step if KL is enabled
        kl_cfg = getattr(getattr(self.config, 'trainer', None), 'kl_regularization', None)
        kl_enabled = kl_cfg is not None and getattr(kl_cfg, 'enabled', False)
        ref_logp_cpu = None
        if kl_enabled:
            ref_logp_cpu = self._compute_ref_logp(batch)
        offset = 0
        for micro_batch in micro_batches:
            mb_size = micro_batch.get("input_ids").shape[0]
            if kl_enabled:
                # Slice the rows of the reference log-probs that belong to this micro-batch
                ref_logp_mb = ref_logp_cpu[offset:offset + mb_size].to(torch.cuda.current_device(), non_blocking=True)
            else:
                ref_logp_mb = None
            loss, ce_loss, kl_loss = self._compute_loss_and_backward(batch=micro_batch, ref_logp_mb=ref_logp_mb)
            # Report the mean over micro-batches
            step_loss += (loss / n_micro_batches).item()
            step_ce_loss += (ce_loss / n_micro_batches).item()
            step_kl_loss += (kl_loss / n_micro_batches).item()
            offset += mb_size

        # backward() ran on each un-scaled micro-batch loss, so the accumulated
        # gradient is the SUM over micro-batches. Rescale it to the gradient of the
        # mean loss (the quantity logged above) before clipping, so the effective
        # gradient does not grow with the number of accumulation steps.
        if n_micro_batches > 1:
            for param in self.fsdp_model.parameters():
                if param.grad is not None:
                    param.grad.div_(n_micro_batches)

        grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)

        log_gpu_memory_usage("Before optimizer step", logger=logger)

        # if grad_norm is not finite, skip the update
        if not torch.isfinite(grad_norm):
            print(f"WARN: grad_norm is not finite: {grad_norm}")
            self.optimizer.zero_grad()
        else:
            self.optimizer.step()

        log_gpu_memory_usage("After optimizer step", logger=logger)

        self.lr_scheduler.step()

        lr = self.lr_scheduler.get_last_lr()[0]

        log_gpu_memory_usage("After offload weights", logger=logger)

        # reduce loss across dp ranks
        step_loss = torch.tensor(step_loss).cuda()
        step_ce_loss = torch.tensor(step_ce_loss).cuda()
        step_kl_loss = torch.tensor(step_kl_loss).cuda()
        torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)
        torch.distributed.all_reduce(step_ce_loss, op=torch.distributed.ReduceOp.AVG)
        torch.distributed.all_reduce(step_kl_loss, op=torch.distributed.ReduceOp.AVG)
        metrics = {"train/loss": step_loss.detach().item(), "train/lr": lr, "train/ce_loss": step_ce_loss.detach().item(), "train/grad_norm": grad_norm.detach().item()}
        if kl_enabled:
            metrics["train/kl_loss"] = step_kl_loss.detach().item()
            metrics["train/kl_coef"] = float(getattr(kl_cfg, 'kl_coef', 0.05))
        if anchor_enabled:
            # Global raw L2 (sum of squared diffs across trainable params and ranks)
            anchor_l2 = self.anchor_step_l2 / n_micro_batches
            torch.distributed.all_reduce(anchor_l2, op=torch.distributed.ReduceOp.SUM)
            metrics["train/l2_from_base"] = float(anchor_l2.detach().item())
        return metrics

    def validation_step(self, batch: TensorDict):
        """Evaluate one validation batch without gradients.

        Puts the FSDP model in eval mode, computes the same loss as training
        (including the KL term when ``config.trainer.kl_regularization.enabled``
        is set) with ``do_backward=False``, and averages each returned value
        across ranks in place.

        Returns:
            ``(loss, ce_loss, kl_loss)`` after the cross-rank average.
        """
        self.fsdp_model.eval()

        trainer_section = getattr(self.config, 'trainer', None)
        kl_cfg = getattr(trainer_section, 'kl_regularization', None)
        wants_kl = kl_cfg is not None and getattr(kl_cfg, 'enabled', False)

        with torch.no_grad():
            # Compute KL if enabled for consistency with training loss
            reference = None
            if wants_kl:
                reference = self._compute_ref_logp(batch).to(torch.cuda.current_device(), non_blocking=True)

            loss, ce_loss, kl_loss = self._compute_loss_and_backward(batch, do_backward=False, ref_logp_mb=reference)

            for value in (loss, ce_loss, kl_loss):
                torch.distributed.all_reduce(value, op=torch.distributed.ReduceOp.AVG)

        return loss, ce_loss, kl_loss

    def save_checkpoint(self, step):
        """Write a HuggingFace-format checkpoint for ``step``.

        All ranks take part in gathering the full FSDP state dict (offloaded to
        CPU, materialized on rank 0 only). Rank 0 then saves the model and the
        tokenizer under ``<default_local_dir>/global_step_<step>`` and, when
        ``config.trainer.default_hdfs_dir`` is set, copies that directory there.
        Every rank waits on a barrier before returning.
        """
        from torch.distributed.fsdp import FullStateDictConfig, StateDictType

        gather_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, gather_cfg):
            full_state = self.fsdp_model.state_dict()

        trainer_cfg = self.config.trainer
        ckpt_dir = os.path.join(trainer_cfg.default_local_dir, f"global_step_{step}")

        # Only the first rank touches the filesystem
        is_writer = self.device_mesh.get_rank() == 0
        if is_writer:
            os.makedirs(ckpt_dir, exist_ok=True)
            self.model.save_pretrained(ckpt_dir, state_dict=full_state)
            self.tokenizer.save_pretrained(ckpt_dir)

            remote_dir = self.config.trainer.default_hdfs_dir
            if remote_dir:
                hdfs_io.makedirs(remote_dir, exist_ok=True)
                hdfs_io.copy(src=ckpt_dir, dst=remote_dir, dirs_exist_ok=True)

        torch.distributed.barrier()

    def fit(self):
        """Run the training loop with validation, policy evaluation and checkpointing.

        Flow:
          1. Rank 0 creates the tracker. Optionally run a validation / policy
             evaluation pass before training (``trainer.val_before_train``).
          2. For every epoch, run ``training_step`` on each batch, log metrics on
             rank 0 and save a checkpoint every ``trainer.save_freq`` steps
             (``-1`` means "save at the end of each epoch" instead).
          3. When ``global_step`` reaches the total number of training steps, run a
             final validation, save a final checkpoint and return.
          4. Otherwise, validate at the end of every epoch.

        Raises:
            RuntimeError: If CUDA is unavailable or the current CUDA device does not
                match this process's ``LOCAL_RANK``.
        """
        rank = self.device_mesh.get_rank()
        trainer_cfg = self.config.trainer

        # TODO: add a unified tracking
        # Only rank 0 owns a tracker; every use below is guarded by `rank == 0`.
        tracking = None
        if rank == 0:
            tracking = Tracking(
                project_name=trainer_cfg.project_name,
                experiment_name=trainer_cfg.experiment_name,
                default_backend=trainer_cfg.logger,
            )

        global_step = 0
        # compute the total training steps.
        # the total training steps in SFT is mainly for early exit
        total_training_steps = len(self.train_dataloader) * trainer_cfg.total_epochs

        if trainer_cfg.total_training_steps is not None:
            total_training_steps = trainer_cfg.total_training_steps

        self.total_training_steps = total_training_steps
        print(f"Total training steps: {self.total_training_steps}")

        # TODO (zhangchi.usc1992) add back checkpoint manager.
        # Currently, it blocks when uploading to hdfs. So very slow.
        if trainer_cfg.get('val_before_train', False):
            if self.val_score_dataset is not None:
                # single turn (the returned metrics are not logged here)
                self._validate()
            elif trainer_cfg.policy_eval:
                # All ranks must participate in FSDP full-state-dict collectives
                self._sync_eval_model_from_fsdp_all_ranks()
                if rank == 0:
                    self._policy_eval_and_log(tracking, global_step)
        # Ensure all ranks synchronize here; barrier must be called by every rank
        torch.distributed.barrier()

        # Device sanity checks. The CUDA device index is per-node, so it has to be
        # compared with LOCAL_RANK: the global rank only coincides with it on a
        # single node.
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        local_rank = int(os.environ["LOCAL_RANK"])
        if torch.cuda.current_device() != local_rank:
            raise RuntimeError(f"Device mismatch: current={torch.cuda.current_device()}, expected={local_rank}")

        kl_cfg = getattr(getattr(self.config, 'trainer', None), 'kl_regularization', None)
        kl_enabled = kl_cfg is not None and getattr(kl_cfg, 'enabled', False)

        def run_loss_validation():
            """Run validation_step over the validation dataloader; return per-batch losses."""
            losses, ce_losses, kl_losses = [], [], []
            for val_data in self.val_dataloader:
                val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda(device=local_rank)
                val_loss, val_ce_loss, val_kl_loss = self.validation_step(val_data)
                losses.append(val_loss)
                ce_losses.append(val_ce_loss)
                if kl_enabled:
                    kl_losses.append(val_kl_loss)
            return losses, ce_losses, kl_losses

        def run_score_validation():
            """Run reward-based validation when a scoring dataset is configured."""
            if self.val_score_dataset is not None:
                return self._validate()
            return None

        def log_validation(losses, ce_losses, kl_losses, score_metrics, step):
            """Log averaged validation metrics; must only be called on rank 0."""
            metric = {}
            # torch.stack raises on an empty list, so skip the loss metrics when the
            # validation dataloader produced no batch.
            if losses:
                metric["val/loss"] = torch.mean(torch.stack(losses)).detach().item()
                metric["val/ce_loss"] = torch.mean(torch.stack(ce_losses)).detach().item()
            if kl_enabled and kl_losses:
                metric["val/kl_loss"] = torch.mean(torch.stack(kl_losses)).detach().item()
            if metric:
                tracking.log(data=metric, step=step)
            if self.val_score_dataset is not None:
                tracking.log(data=score_metrics, step=step)

        for epoch in range(trainer_cfg.total_epochs):
            self.train_sampler.set_epoch(epoch=epoch)
            for data in tqdm(
                self.train_dataloader,
                total=self.steps_per_epoch,
                desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}",
            ):
                global_step += 1

                data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda(device=local_rank)

                metric = self.training_step(data)
                if rank == 0:
                    tracking.log(data=metric, step=global_step)

                saved_this_step = trainer_cfg.save_freq != -1 and global_step % trainer_cfg.save_freq == 0
                if saved_this_step:
                    self.save_checkpoint(step=global_step)

                # for early exit validation
                if global_step >= self.total_training_steps:
                    # Perform final validation
                    losses, ce_losses, kl_losses = run_loss_validation()
                    score_metrics = run_score_validation()

                    if trainer_cfg.policy_eval:
                        # All ranks refresh; rank 0 runs rollout (and logs its own metrics)
                        self._sync_eval_model_from_fsdp_all_ranks()
                        if rank == 0:
                            self._policy_eval_and_log(tracking, global_step)

                    if rank == 0:
                        log_validation(losses, ce_losses, kl_losses, score_metrics, global_step)

                    torch.distributed.barrier()

                    # Save final checkpoint, unless this exact step was already
                    # saved by the periodic save above (no update happened since).
                    if not saved_this_step:
                        self.save_checkpoint(step=global_step)
                    return

            # validation
            losses, ce_losses, kl_losses = run_loss_validation()
            score_metrics = run_score_validation()

            if rank == 0:
                log_validation(losses, ce_losses, kl_losses, score_metrics, global_step)

            # Check policy_eval first so test_freq is only consulted when it is needed
            if trainer_cfg.policy_eval and global_step % trainer_cfg.test_freq == 0:
                # Periodic eval uses refreshed eval model; all ranks participate
                self._sync_eval_model_from_fsdp_all_ranks()
                if rank == 0:
                    self._policy_eval_and_log(tracking, global_step)

            torch.distributed.barrier()

            # Periodic saves (save_freq != -1) already happened inside the step loop
            # for this global_step; only the per-epoch mode saves here.
            if trainer_cfg.save_freq == -1:
                self.save_checkpoint(step=global_step)




@hydra.main(config_path="config", config_name="sft_trainer", version_base=None)
def main(config):
    """Hydra entry point: set up distributed state, build datasets and run training.

    Initializes the global process group, builds a 1-D "fsdp" device mesh over all
    ranks and a 2-D ("dp", "sp") mesh for Ulysses sequence parallelism, creates the
    tokenizer and the train/validation datasets, then runs ``FSDPSFTTrainer.fit``
    and tears the process group down.

    Raises:
        ValueError: If ``config.ulysses_sequence_parallel_size`` is not a positive
            divisor of the world size.
    """
    local_rank, rank, world_size = initialize_global_process_group()

    sp_size = config.ulysses_sequence_parallel_size
    # Floor division below would silently build a (dp, sp) mesh that does not cover
    # every rank, so reject incompatible sizes up front.
    if sp_size < 1 or world_size % sp_size != 0:
        raise ValueError(f"ulysses_sequence_parallel_size ({sp_size}) must be a positive divisor of world_size ({world_size})")

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
    dp_size = world_size // sp_size
    ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp"))
    # build tokenizer and datasets first
    from verl.utils import hf_tokenizer

    local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
    if config.data.type == 'reasoning_gym':
        train_dataset, val_dataset = prepare_reasoning_gym_sft_dataset(config.data.reasoning_gym, tokenizer)
    else:
        train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
        val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)

    trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset)

    trainer.fit()

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


def create_sft_dataset(data_paths, data_config, tokenizer):
    """Instantiate the SFT dataset for ``data_paths``.

    When ``data_config.custom_cls`` carries a non-empty "path", the class named
    ``data_config.custom_cls.name`` is loaded from that path; otherwise the
    single-turn ``SFTDataset`` is used. The chosen class is constructed with
    ``parquet_files``, ``tokenizer`` and ``config`` keyword arguments.
    """
    custom_cls_cfg = data_config.custom_cls

    if not custom_cls_cfg.get("path", None):
        # No custom class configured: fall back to the single-turn dataset
        chosen_cls = SFTDataset
    else:
        from verl.utils.import_utils import load_extern_type

        chosen_cls = load_extern_type(custom_cls_cfg.path, custom_cls_cfg.name)

    return chosen_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config)

def convert_tensordict_to_dataproto(tensordict_data: TensorDict, 
                                   non_tensor_data: Dict[str, Any] | None = None,
                                   meta_info: Dict[str, Any] | None = None) -> DataProto:
    """
    Convert TensorDict to DataProto format for reward function evaluation.
    
    Args:
        tensordict_data: TensorDict containing the batch data. Only entries whose
            value is a ``torch.Tensor`` are forwarded.
        non_tensor_data: Optional dictionary containing non-tensor data. Values that
            are already numpy arrays are passed through unchanged; anything else is
            converted with ``np.array(value, dtype=object)``. A value that cannot be
            converted is left out and a warning is logged.
        meta_info: Optional dictionary containing metadata (defaults to an empty dict)
        
    Returns:
        DataProto: The converted DataProto object
    """
    # Extract tensors from TensorDict
    tensors = {key: value for key, value in tensordict_data.items() if isinstance(value, torch.Tensor)}

    # Convert non-tensor data to numpy arrays with dtype=object
    processed_non_tensors = {}
    for key, value in (non_tensor_data or {}).items():
        if isinstance(value, np.ndarray):
            processed_non_tensors[key] = value
            continue
        try:
            processed_non_tensors[key] = np.array(value, dtype=object)
        except (TypeError, ValueError) as err:
            # Best effort: keep going without this field, but make the drop visible
            logging.getLogger(__name__).warning(
                "Dropping non-tensor field %r: cannot convert it to an object array (%s)", key, err
            )

    # Create DataProto using the from_dict method
    return DataProto.from_dict(
        tensors=tensors,
        non_tensors=processed_non_tensors,
        meta_info=meta_info or {}
    )

# Script entry point: run the Hydra-decorated main() only when this file is executed directly.
if __name__ == "__main__":
    main()